// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_WEIGHT_BUFFER_H
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_WEIGHT_BUFFER_H

#include <mc_scverify.h>

#include "src/AccelTypes.h"
#include "src/Deserializer.h"

// Address generator for the weight buffer: emits one word address per weight
// element needed by the systolic array.
template <int NROWS, int NCOLS>
class WeightFetcher {
 public:
  WeightFetcher() {}

  // Consumes one Params record and writes addresses to `addressRequest`.
  //
  // For every (p2, m1, n1, p1) iteration, NROWS * NCOLS addresses are emitted.
  // p0 varies fastest, so each run of NCOLS consecutive addresses covers one
  // tile row. m1 does not enter the address arithmetic, so the same addresses
  // are re-issued once per m1 iteration.
  //
  // With row = n1 * NROWS + n0 and col = p2 * P1 * NCOLS + p1 * NCOLS + p0,
  // the emitted address is WEIGHT_OFFSET plus
  //   col * (NROWS * N1) + row          when TRANSPOSE is set, or
  //   row * (P1 * P2 * NCOLS) + col     otherwise.
#pragma hls_design interface
  void CCS_BLOCK(run)(ac_channel<int> &addressRequest,
                      ac_channel<Params> &paramsIn) {
    Params cfg = paramsIn.read();

    for (int p2 = 0; p2 < cfg.P2; p2++) {
      for (int m1 = 0; m1 < cfg.M1; m1++) {
        for (int n1 = 0; n1 < cfg.N1; n1++) {
          for (int p1 = 0; p1 < cfg.P1; p1++) {
            for (int n0 = 0; n0 < NROWS; n0++) {
              for (int p0 = 0; p0 < NCOLS; p0++) {
                // Logical (row, col) coordinates of this weight element.
                int row = n0 + n1 * NROWS;
                int col = p2 * cfg.P1 * NCOLS + p1 * NCOLS + p0;

                int offset;
                if (cfg.TRANSPOSE) {
                  // Column stride is the full row count NROWS * N1.
                  offset = col * (NROWS * cfg.N1) + row;
                } else {
                  // Row stride is the full column count P1 * P2 * NCOLS.
                  offset = row * (cfg.P1 * cfg.P2 * NCOLS) + col;
                }

                int address = cfg.WEIGHT_OFFSET + offset;
                addressRequest.write(address);
              }
            }
          }
        }
      }
    }
  }
};

// Packs deserialized weight rows into per-tile buffers and forwards each
// filled buffer downstream.
template <typename WDTYPE, int NROWS, int NCOLS, int BUFFER_SIZE>
class WeightWriter {
 public:
  WeightWriter() {}

  // Consumes one Params record. For each (p2, m1, n1) iteration it reads
  // P1 * NROWS rows from `dataResponse` (ordered by p1, then n0) into a local
  // buffer, then writes that buffer to `weightBuffer` once.
  //
  // The row for (n0, p1) is stored at slot n0 * P1 + p1. The largest slot
  // touched is NROWS * P1 - 1, so the caller must ensure
  // BUFFER_SIZE >= NROWS * P1; this is not checked here. This loop does not
  // write slots at or beyond NROWS * P1.
#pragma hls_design interface
  void CCS_BLOCK(run)(
      ac_channel<Pack1D<WDTYPE, NCOLS> > &dataResponse,
      ac_channel<chanStruct<Pack1D<WDTYPE, NCOLS>, BUFFER_SIZE> > &weightBuffer,
      ac_channel<Params> &paramsIn) {
    Params cfg = paramsIn.read();

    for (int p2 = 0; p2 < cfg.P2; p2++) {
      for (int m1 = 0; m1 < cfg.M1; m1++) {
        for (int n1 = 0; n1 < cfg.N1; n1++) {
          chanStruct<Pack1D<WDTYPE, NCOLS>, BUFFER_SIZE> tile;
          for (int p1 = 0; p1 < cfg.P1; p1++) {
            for (int n0 = 0; n0 < NROWS; n0++) {
              // Row-major by n0 with stride P1 (the layout WeightReader uses).
              int slot = n0 * cfg.P1 + p1;
              tile.data[slot] = dataResponse.read();
            }
          }
          weightBuffer.write(tile);
        }
      }
    }
  }
};

// Drains per-tile weight buffers and streams their rows to the systolic
// array.
template <typename WDTYPE, int NROWS, int NCOLS, int BUFFER_SIZE>
class WeightReader {
 public:
  WeightReader() {}

  // Consumes one Params record. For each (p2, m1, n1) iteration it reads one
  // buffer from `weightBuffer`. For every p1 it emits the NROWS rows stored at
  // slots n0 * P1 + p1, in descending n0 order (NROWS - 1 first, 0 last).
#pragma hls_design interface
  void CCS_BLOCK(run)(
      ac_channel<chanStruct<Pack1D<WDTYPE, NCOLS>, BUFFER_SIZE> > &weightBuffer,
      ac_channel<Pack1D<WDTYPE, NCOLS> > &weightsToSystolicArray,
      ac_channel<Params> &paramsIn) {
    Params cfg = paramsIn.read();

    for (int p2 = 0; p2 < cfg.P2; p2++) {
      for (int m1 = 0; m1 < cfg.M1; m1++) {
        for (int n1 = 0; n1 < cfg.N1; n1++) {
          chanStruct<Pack1D<WDTYPE, NCOLS>, BUFFER_SIZE> tile =
              weightBuffer.read();
          for (int p1 = 0; p1 < cfg.P1; p1++) {
            // Rows are loaded in reverse order: k = 0 selects n0 = NROWS - 1.
            for (int k = 0; k < NROWS; k++) {
              int slot = (NROWS - 1 - k) * cfg.P1 + p1;
              weightsToSystolicArray.write(tile.data[slot]);
            }
          }
        }
      }
    }
  }
};

// Top-level weight path: ties together the address fetcher, the serial-to-
// parallel deserializer, the buffer writer and the buffer reader.
//
// Data flow:
//   weightFetcher      -> addressRequest (out)
//   serialDataResponse (in) -> weightDeserializer -> parallelDataResponse
//   parallelDataResponse -> weightWriter -> weightBuffer
//   weightBuffer       -> weightReader -> weightsToSystolicArray (out)
template <typename WDTYPE, int NROWS, int NCOLS, int BUFFER_SIZE>
class WeightBuffer {
 public:
  WeightBuffer() {}

  // Handles Params records from `paramsIn`. Each record is copied to the
  // fetcher, writer and reader through their own Params channels, and then
  // the four sub-blocks are invoked in sequence.
  //
  // When __SYNTHESIS__ is not defined (C simulation), the body repeats while
  // `paramsIn` still has a record available. Under synthesis the guard drops
  // out and each call handles exactly one record.
#pragma hls_design interface
  void CCS_BLOCK(run)(
      ac_channel<int> &addressRequest, ac_channel<WDTYPE> &serialDataResponse,
      ac_channel<Pack1D<WDTYPE, NCOLS> > &weightsToSystolicArray,
      ac_channel<Params> &paramsIn) {
#ifndef __SYNTHESIS__
    while (paramsIn.available(1))
#endif
    {
      // Broadcast one configuration to every stage that reads Params.
      Params params = paramsIn.read();
      weightFetcherParams.write(params);
      weightWriterParams.write(params);
      weightReaderParams.write(params);

      // Sub-blocks communicate only through the member channels below.
      // NOTE(review): weightDeserializer takes no Params; presumably it packs
      // NCOLS serial WDTYPE words into one Pack1D per output. Confirm against
      // Deserializer.h.
      weightFetcher.run(addressRequest, weightFetcherParams);
      weightDeserializer.run(serialDataResponse, parallelDataResponse);
      weightWriter.run(parallelDataResponse, weightBuffer, weightWriterParams);
      weightReader.run(weightBuffer, weightsToSystolicArray,
                       weightReaderParams);
    }
  }

 private:
  // Carries one filled tile buffer from weightWriter to weightReader.
  ac_channel<chanStruct<Pack1D<WDTYPE, NCOLS>, BUFFER_SIZE> > weightBuffer;

  // Address generation stage and its private copy of Params.
  WeightFetcher<NROWS, NCOLS> weightFetcher;
  ac_channel<Params> weightFetcherParams;

  // Serial-to-parallel stage and its output channel to the writer.
  Deserializer<WDTYPE, NCOLS> weightDeserializer;
  ac_channel<Pack1D<WDTYPE, NCOLS> > parallelDataResponse;

  // Buffer-fill stage and its private copy of Params.
  WeightWriter<WDTYPE, NROWS, NCOLS, BUFFER_SIZE> weightWriter;
  ac_channel<Params> weightWriterParams;

  // Buffer-drain stage and its private copy of Params.
  WeightReader<WDTYPE, NROWS, NCOLS, BUFFER_SIZE> weightReader;
  ac_channel<Params> weightReaderParams;
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_WEIGHT_BUFFER_H
